package com.my.service.task.logistic;

import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.util.Collector;

import java.util.ArrayList;
import java.util.Iterator;

/**
 * Flink group-reduce function that trains a logistic-regression model on all
 * {@link LogisticInfo} records of a group and emits the resulting weights.
 *
 * <p>Each record contributes one training row made of its three string features
 * ({@code var1}, {@code var2}, {@code var3}) plus its string label. Training is
 * delegated to {@code Logistic.gradAscent1}; exactly one weight list is collected
 * per invocation of {@link #reduce}.
 */
public class LogisticReducer implements GroupReduceFunction<LogisticInfo, ArrayList<Double>> {

    /** Iteration count used by the no-arg constructor (the previously hard-coded value). */
    public static final int DEFAULT_ITERATIONS = 500;

    /**
     * Value forwarded as the third argument of {@code Logistic.gradAscent1};
     * assumed to be the number of training iterations — confirm against {@code Logistic}.
     */
    private final int numIterations;

    /** Creates a reducer that trains with {@link #DEFAULT_ITERATIONS} iterations. */
    public LogisticReducer() {
        this(DEFAULT_ITERATIONS);
    }

    /**
     * Creates a reducer with a custom iteration count.
     *
     * @param numIterations value passed to {@code Logistic.gradAscent1} as its
     *     iteration count; must not be negative
     * @throws IllegalArgumentException if {@code numIterations} is negative
     */
    public LogisticReducer(int numIterations) {
        if (numIterations < 0) {
            throw new IllegalArgumentException("numIterations must not be negative: " + numIterations);
        }
        this.numIterations = numIterations;
    }

    /**
     * Builds a training set from every record in the group, trains on it and
     * emits the learned weights.
     *
     * @param iterable  the records of one group
     * @param collector receives exactly one weight list
     * @throws Exception propagated from training, as permitted by the Flink interface
     */
    @Override
    public void reduce(Iterable<LogisticInfo> iterable, Collector<ArrayList<Double>> collector) throws Exception {
        CreateDataSet dataSet = new CreateDataSet();
        for (LogisticInfo info : iterable) {
            // Build the training set: one feature row and one label per record.
            ArrayList<String> features = new ArrayList<>(3);
            features.add(info.getVar1());
            features.add(info.getVar2());
            features.add(info.getVar3());
            dataSet.data.add(features);
            dataSet.labels.add(info.getLabel());
        }
        ArrayList<Double> weights = Logistic.gradAscent1(dataSet, dataSet.labels, numIterations);
        collector.collect(weights);
    }
}
